1 - Image Classification Task

  1. Train 2 different models and get results. Download the image classification data.
    • Train a classification model (using PyTorch or Tensorflow) to classify the tissue images into organ systems they come from. (Do not use a pre-trained model. You should create a model and a dataloader from scratch.)
    • Train a classification model (using Pytorch or Tensorflow) to classify the tissue images into organ systems they come from. Use a pre-trained model such as VGG, Inception, Efficientnet etc. You may use in-built functions to create your model and dataloader.
    • Calculate the training and test accuracy of your model.
  2. Visualize
    • Overlap between training and test datasets in 2D, e.g., using t-SNE, UMAP, MDS etc.
    • Prediction results
In [1]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Importing the libraries and dependencies and drive

The following cell block imports all the libraries, functions, methods that are being used inn this jupyter notebook.
The Deep Learning/Machine Learning Libarary being used here is the high level keras API of tensorflow 2.X.
And the plotting is done with matplotlib or seaborn plotting packages.

Along with this all the constants that will be common throughout the file are defined and declared here.

In [ ]:
!pip install tensorflow_io
Requirement already satisfied: tensorflow_io in /usr/local/lib/python3.7/dist-packages (0.22.0)
Requirement already satisfied: tensorflow-io-gcs-filesystem==0.22.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow_io) (0.22.0)
Requirement already satisfied: tensorflow<2.8.0,>=2.7.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow_io) (2.7.0)
Requirement already satisfied: keras-preprocessing>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.1.2)
Requirement already satisfied: gast<0.5.0,>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.4.0)
Requirement already satisfied: libclang>=9.0.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (12.0.0)
Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.15.0)
Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.17.3)
Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.42.0)
Requirement already satisfied: keras<2.8,>=2.7.0rc0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (2.7.0)
Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.1.0)
Requirement already satisfied: wheel<1.0,>=0.32.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.37.0)
Requirement already satisfied: tensorboard~=2.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (2.7.0)
Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.2.0)
Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.13.3)
Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.1.0)
Requirement already satisfied: tensorflow-estimator<2.8,~=2.7.0rc0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (2.7.0)
Requirement already satisfied: absl-py>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.12.0)
Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.10.0.2)
Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.3.0)
Requirement already satisfied: numpy>=1.14.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.19.5)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (2.0)
Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.6.3)
Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py>=2.9.0->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.5.2)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.6.1)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.4.6)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.8.0)
Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (2.23.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.3.6)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.0.1)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (57.4.0)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.35.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (4.8)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.2.8)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (4.2.4)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.3.0)
Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (4.8.2)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.6.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (0.4.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (2021.10.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow<2.8.0,>=2.7.0->tensorflow_io) (3.1.1)
In [ ]:
!nvidia-smi
Sun Dec 19 03:06:42 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
In [2]:
import pandas as pd
from PIL import Image, ImageDraw
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import seaborn as sns
from keras.callbacks import CSVLogger
import tensorflow_datasets as tfds
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import random, cv2, os, cv2, gc, json
# import tensorflow_io as tfio
from tensorflow.keras.applications import ResNet50
from sklearn.metrics import confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline

EPOCHS = 50
BATCH_SIZE = 16
CLASSES = 12
IMG_DIM = (3000, 3000, 3)
NEW_IMG_SIZE = (512, 512, 3)
SHUFFLE = False
NUM_LAYERS = 5
ACCURACY_THRESHOLD = 0.98
LEARNING_RATE = 0.001
SEED = 1
TRAIN = True

Loading Data

The following block of code loads data in two formats.

  • First is use the most common way - manually through numpy array.
  • Second is the TFDS - Tensorflow dataset module which loads data in the format of tenorflow_dataset whic makes it fast o process and load data onto and away from the memory(GPU/CPU)

I experimented with both the above techniques but the second one was very fast and thats what i am sticking with. As it loads the next batch of data onto GPU memory while the current batch is being operated on - this is done with the prefetching functionality of the tensirflow_dataset module.

In [5]:
path_train = "/content/drive/MyDrive/datas/Data/Image Classification Data/data/train/"
path_test = "/content/drive/MyDrive/datas/Data/Image Classification Data/data/test/"
classes = os.listdir('/content/drive/MyDrive/datas/Data/Image Classification Data/data/train')
In [ ]:
train_df = pd.DataFrame({"images":[],"labels":[]})
for cls in classes:
    images = os.listdir(path_train+cls)
    for img in range(len(images)):
        image = np.array(Image.open(path_train + cls + "/" + images[img]))
        images[img] = cv2.resize(image, (512,512)).astype(np.int16)
        del image
    temp = pd.DataFrame({"images":images,"labels":[cls]*(len(images))})
    del images
    train_df = pd.concat([train_df, temp])
    del temp
    gc.collect()

train_df.shape
In [ ]:
test_df = pd.DataFrame({"images":[],"labels":[]})
for cls in classes:
    images = os.listdir(path_test+cls)
    for img in range(len(images)):
        image = np.array(Image.open(path_test + cls + "/" + images[img]))
        images[img] = cv2.resize(image, (512,512)).astype(np.int16)
        del image
    temp = pd.DataFrame({"images":images,"labels":[cls]*(len(images))})
    del images
    test_df = pd.concat([test_df, temp])
    del temp
    gc.collect()

test_df.shape
In [3]:
train_images = np.load("/content/drive/MyDrive/datas/train_images.npy",allow_pickle=True)
train_labels = np.load("/content/drive/MyDrive/datas/train_labelss.npy",allow_pickle=True)
train_df = pd.DataFrame({"images":train_images, "labels":train_labels})

test_images = np.load("/content/drive/MyDrive/datas/test_images.npy",allow_pickle=True)
test_labels = np.load("/content/drive/MyDrive/datas/test_labels.npy",allow_pickle=True)
test_df = pd.DataFrame({"images":test_images, "labels":test_labels})

Labels, preprocessing & Look at the data

In [6]:
classes_TO_labels = {cls:i for i,cls in enumerate(classes)}
labels_TO_classes = {y:x for x, y in classes_TO_labels.items()}
train_df['labels'] = train_df['labels'].replace(classes_TO_labels)
test_df['labels'] = test_df['labels'].replace(classes_TO_labels)
classes_TO_labels
Out[6]:
{'colon': 10,
 'endometrium_1': 11,
 'endometrium_2': 8,
 'kidney': 5,
 'liver': 6,
 'lung': 2,
 'lymph_node': 7,
 'pancreas': 9,
 'skin_1': 3,
 'skin_2': 0,
 'small_intestine': 1,
 'spleen': 4}
In [39]:
def generator():
    for i in train_df.iterrows():
        # t = np.array(Image.open(i[1]['images']))
        # x = cv2.resize(t, (512,512))
        y = tf.one_hot(i[1]['labels'], depth=12)
        yield tf.cast(i[1]['images'], tf.float32)/255.0, tf.cast(y, tf.int8)


def generator_test():
    for i in test_df.iterrows():
        y = tf.one_hot(i[1]['labels'], depth=12)
        yield tf.cast(i[1]['images'], tf.float32)/255.0, tf.cast(y, tf.int8)

dataset = tf.data.Dataset.from_generator(generator, (tf.float32, tf.int8), output_shapes=((512,512,3), (12)))
dataset_test = tf.data.Dataset.from_generator(generator_test, (tf.float32, tf.int8), output_shapes=((512,512,3), (12)))

ds_train = dataset.cache()
ds_train = ds_train.shuffle(1200)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

ds_test = dataset_test.batch(BATCH_SIZE)
# ds_test = ds_test.shuffle(600)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)
In [40]:
plt.figure(figsize = (10,10))
for i in ds_train.take(1):
    plt.subplot(1, 2, 1)
    plt.imshow(i[0][0])
    plt.title("Training data example")

for i in ds_test.take(1):
    plt.subplot(1, 2, 2)
    plt.imshow(i[0][0])
    plt.title("Testing data example")

Model architecture & helper methods

There are three helper methods :-

  • seed_everthing - sets the seed for numpy, tensorflow and random modules so that when we try and replicate a result we can do so if the seeds are same.
  • clear_tf - clears the backend of the tensorflow module.
  • Model - this method defines a common method which will be used commonly accross to define a Deep Learning model.
In [41]:
class CustomCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))
        if(logs.get('val_accuracy') > ACCURACY_THRESHOLD):   
            print("\nReached %2.2f%% validation accuracy, so stopping training!!" %(ACCURACY_THRESHOLD*100))   
            self.model.stop_training = True

checkpoint = keras.callbacks.ModelCheckpoint('', monitor='accuracy', mode='max', verbose=1, save_best_only=True)
In [42]:
def seed_everything(n):
    np.random.seed(n)
    tf.compat.v1.random.set_random_seed(n)
    random.seed(n)

def clear_tf():
    tf.keras.backend.clear_session()

def Model(inp_size = IMG_DIM,
          layers = NUM_LAYERS, 
          layers_nodes = [4, 8, 16, 8, 4], 
          dropout_include = False, 
          max_pooling_num = 2,
          strides = 2, 
          kernel_size = 3,
          classes = CLASSES, 
          dropout_ratio = [0.2]*NUM_LAYERS,
          activation_fn = keras.layers.LeakyReLU(),
          initializer = 'glorot_uniform',    # Default value
          loss_fn = "categorical_crossentropy",#tf.keras.losses.SparseCategoricalCrossentropy(),
          optimizer = keras.optimizers.Adam(learning_rate=0.0001), 
          include_normalization = True):
    
    assert layers == len(layers_nodes), "Layer size and number of nodes for each layer are not equal"
    assert layers == len(dropout_ratio), "Dropout array size and number of layers do not match"

    inp = tf.keras.layers.Input(shape = inp_size)
    # layer_x = tf.keras.layers.Flatten()(inp)
    layer_x = tf.keras.layers.Conv2D(layers_nodes[0], strides = strides, kernel_size = kernel_size, 
                                 activation = activation_fn,
                                 kernel_initializer = initializer,
                                 name = 'Conv2D_Layer_1')(inp)
    # layer_x = tf.keras.layers.MaxPooling2D(max_pooling_num)(layer_x)
    if include_normalization:
        layer_x = tf.keras.layers.BatchNormalization()(layer_x)

    if dropout_include:
        layer_x = tf.keras.layers.Dropout(dropout_ratio[0])(layer_x)

    for layer_num in range(1, layers):
        layer_x = tf.keras.layers.Conv2D(layers_nodes[layer_num], strides = strides, kernel_size = kernel_size,  
                                     activation = activation_fn,
                                     kernel_initializer = initializer,
                                     name = 'Conv2D_Layer_'+str(layer_num+1))(layer_x)
        # layer_x = tf.keras.layers.MaxPooling2D(max_pooling_num)(layer_x)
        if include_normalization:
            layer_x = tf.keras.layers.BatchNormalization()(layer_x)
        if dropout_include:
            layer_x = tf.keras.layers.Dropout(dropout_ratio[layer_num])(layer_x)
    
    layer_x = tf.keras.layers.Flatten()(layer_x)
    out = tf.keras.layers.Dense(classes, activation = "softmax", name = "Output")(layer_x)

    model = tf.keras.models.Model(inputs = inp, outputs = out)
    model.compile(loss = loss_fn, 
                  optimizer = optimizer, 
                  metrics = ['accuracy'])
    
    return model

clear_tf()
seed_everything(SEED)
Model(inp_size=(512,512,3)).summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 512, 512, 3)]     0         
                                                                 
 Conv2D_Layer_1 (Conv2D)     (None, 255, 255, 4)       112       
                                                                 
 batch_normalization (BatchN  (None, 255, 255, 4)      16        
 ormalization)                                                   
                                                                 
 Conv2D_Layer_2 (Conv2D)     (None, 127, 127, 8)       296       
                                                                 
 batch_normalization_1 (Batc  (None, 127, 127, 8)      32        
 hNormalization)                                                 
                                                                 
 Conv2D_Layer_3 (Conv2D)     (None, 63, 63, 16)        1168      
                                                                 
 batch_normalization_2 (Batc  (None, 63, 63, 16)       64        
 hNormalization)                                                 
                                                                 
 Conv2D_Layer_4 (Conv2D)     (None, 31, 31, 8)         1160      
                                                                 
 batch_normalization_3 (Batc  (None, 31, 31, 8)        32        
 hNormalization)                                                 
                                                                 
 Conv2D_Layer_5 (Conv2D)     (None, 15, 15, 4)         292       
                                                                 
 batch_normalization_4 (Batc  (None, 15, 15, 4)        16        
 hNormalization)                                                 
                                                                 
 flatten (Flatten)           (None, 900)               0         
                                                                 
 Output (Dense)              (None, 12)                10812     
                                                                 
=================================================================
Total params: 14,000
Trainable params: 13,920
Non-trainable params: 80
_________________________________________________________________

Custom data loader & generator

In [ ]:
class DataGenerator(keras.utils.Sequence):
    def __init__(self, dataset, labels, 
                 batch_size = BATCH_SIZE, 
                 dim = (256,256,3),#NEW_IMG_SIZE, 
                 n_classes = CLASSES, 
                 shuffle = SHUFFLE):
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.dataset = dataset
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.list_IDs = np.arange(self.dataset.shape[0])
        self.indexes = np.arange(len(self.list_IDs))
        # self.on_epoch_end()
    
    def __len__(self):
        return int(np.floor(self.dataset.shape[0] / self.batch_size))

    def __getitem__(self, index):

        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        
        X = np.empty((self.batch_size, *self.dim))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            im = np.array(Image.open(self.dataset[ID]))
            X[i,] = cv2.resize(im, (256,256))
            y[i] = self.labels[ID]

        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

train_generator = DataGenerator(train_df['images'].values, 
                                train_df['labels'].values)
val_generator = DataGenerator(test_df['images'].values, 
                              test_df['labels'].values)
In [ ]:
# checkpoint1 = keras.callbacks.ModelCheckpoint('content/drive/checkpoint.h5', monitor='loss', mode='min', verbose=1, save_best_only=True)
# logger = CSVLogger('content/drive/logs.log', separator=',', append=False)

model = Model(inp_size=(512, 512, 3))
model.fit(train_generator,
            validation_data=val_generator,
            use_multiprocessing=True, 
            epochs=EPOCHS,)
Epoch 1/20
150/150 [==============================] - 1673s 11s/step - loss: 3.2451 - accuracy: 0.0983 - val_loss: 2.9240 - val_accuracy: 0.1333
Epoch 2/20
150/150 [==============================] - 2813s 19s/step - loss: 2.8484 - accuracy: 0.1142 - val_loss: 2.5242 - val_accuracy: 0.1583
Epoch 3/20
150/150 [==============================] - 2823s 19s/step - loss: 2.5317 - accuracy: 0.1633 - val_loss: 2.3660 - val_accuracy: 0.1717
Epoch 4/20
150/150 [==============================] - 2508s 17s/step - loss: 2.3357 - accuracy: 0.1992 - val_loss: 2.2654 - val_accuracy: 0.2300
Epoch 5/20
150/150 [==============================] - 2482s 17s/step - loss: 2.1284 - accuracy: 0.2483 - val_loss: 2.2415 - val_accuracy: 0.2367
Epoch 6/20
 48/150 [========>.....................] - ETA: 16:10 - loss: 1.9416 - accuracy: 0.3333
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-12-9a96b4299abe> in <module>()
      6             validation_data=val_generator,
      7             use_multiprocessing=True,
----> 8             epochs=EPOCHS,)
      9             # callbacks = [CustomCallback()])

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     62     filtered_tb = None
     63     try:
---> 64       return fn(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1214                 _r=1):
   1215               callbacks.on_train_batch_begin(step)
-> 1216               tmp_logs = self.train_function(iterator)
   1217               if data_handler.should_sync:
   1218                 context.async_wait()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    148     filtered_tb = None
    149     try:
--> 150       return fn(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    908 
    909       with OptionalXlaContext(self._jit_compile):
--> 910         result = self._call(*args, **kwds)
    911 
    912       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    940       # In this case we have created variables on the first call, so we run the
    941       # defunned version which is guaranteed to never create variables.
--> 942       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    943     elif self._stateful_fn is not None:
    944       # Release the lock early so that multiple threads can perform the call

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   3129        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   3130     return graph_function._call_flat(
-> 3131         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   3132 
   3133   @property

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1958       # No tape is watching; skip to running the function.
   1959       return self._build_call_outputs(self._inference_function.call(
-> 1960           ctx, args, cancellation_manager=cancellation_manager))
   1961     forward_backward = self._select_forward_and_backward_functions(
   1962         args,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    601               inputs=args,
    602               attrs=attrs,
--> 603               ctx=ctx)
    604         else:
    605           outputs = execute.execute_with_cancellation(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57     ctx.ensure_initialized()
     58     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 59                                         inputs, attrs, num_outputs)
     60   except core._NotOkStatusException as e:
     61     if name is not None:

KeyboardInterrupt: 

Note:- Unfortunately due to the huge image size/resolution the training of the model with custom data loader was taking around 35 minutes per epoch so i switched gears to a custom tfds data generator.

In [ ]:
clear_tf()
model = Model(inp_size=(512, 512, 3))
model.fit(ds_train,
            validation_data=ds_test,
            # use_multiprocessing=True, 
            epochs=EPOCHS,)
Epoch 1/50
75/75 [==============================] - 3s 36ms/step - loss: 3.0955 - accuracy: 0.1017 - val_loss: 3.1939 - val_accuracy: 0.0933
Epoch 2/50
75/75 [==============================] - 2s 32ms/step - loss: 2.6254 - accuracy: 0.1675 - val_loss: 3.0042 - val_accuracy: 0.1100
Epoch 3/50
75/75 [==============================] - 2s 32ms/step - loss: 2.3560 - accuracy: 0.2117 - val_loss: 2.5629 - val_accuracy: 0.1700
Epoch 4/50
75/75 [==============================] - 2s 32ms/step - loss: 2.1891 - accuracy: 0.2542 - val_loss: 2.3464 - val_accuracy: 0.2017
Epoch 5/50
75/75 [==============================] - 2s 32ms/step - loss: 2.0710 - accuracy: 0.2975 - val_loss: 2.2321 - val_accuracy: 0.2000
Epoch 6/50
75/75 [==============================] - 2s 32ms/step - loss: 1.9896 - accuracy: 0.3200 - val_loss: 2.1560 - val_accuracy: 0.2267
Epoch 7/50
75/75 [==============================] - 2s 32ms/step - loss: 1.8979 - accuracy: 0.3500 - val_loss: 2.1214 - val_accuracy: 0.2317
Epoch 8/50
75/75 [==============================] - 2s 33ms/step - loss: 1.8352 - accuracy: 0.3783 - val_loss: 2.0860 - val_accuracy: 0.2550
Epoch 9/50
75/75 [==============================] - 2s 32ms/step - loss: 1.7735 - accuracy: 0.3908 - val_loss: 2.0452 - val_accuracy: 0.2767
Epoch 10/50
75/75 [==============================] - 2s 32ms/step - loss: 1.7334 - accuracy: 0.4025 - val_loss: 2.0249 - val_accuracy: 0.3000
Epoch 11/50
75/75 [==============================] - 2s 32ms/step - loss: 1.6895 - accuracy: 0.4225 - val_loss: 2.0159 - val_accuracy: 0.3033
Epoch 12/50
75/75 [==============================] - 2s 33ms/step - loss: 1.6186 - accuracy: 0.4767 - val_loss: 1.9955 - val_accuracy: 0.3050
Epoch 13/50
75/75 [==============================] - 2s 32ms/step - loss: 1.5976 - accuracy: 0.4708 - val_loss: 1.9796 - val_accuracy: 0.2983
Epoch 14/50
75/75 [==============================] - 2s 32ms/step - loss: 1.5497 - accuracy: 0.4875 - val_loss: 1.9896 - val_accuracy: 0.2967
Epoch 15/50
75/75 [==============================] - 2s 32ms/step - loss: 1.5153 - accuracy: 0.5025 - val_loss: 1.9696 - val_accuracy: 0.3017
Epoch 16/50
75/75 [==============================] - 2s 33ms/step - loss: 1.4719 - accuracy: 0.5167 - val_loss: 1.9528 - val_accuracy: 0.3183
Epoch 17/50
75/75 [==============================] - 2s 31ms/step - loss: 1.4414 - accuracy: 0.5350 - val_loss: 1.9434 - val_accuracy: 0.3083
Epoch 18/50
75/75 [==============================] - 2s 32ms/step - loss: 1.4209 - accuracy: 0.5458 - val_loss: 1.9540 - val_accuracy: 0.3250
Epoch 19/50
75/75 [==============================] - 2s 33ms/step - loss: 1.3644 - accuracy: 0.5567 - val_loss: 1.9717 - val_accuracy: 0.3300
Epoch 20/50
75/75 [==============================] - 2s 32ms/step - loss: 1.3297 - accuracy: 0.5825 - val_loss: 1.9590 - val_accuracy: 0.3333
Epoch 21/50
75/75 [==============================] - 2s 32ms/step - loss: 1.3062 - accuracy: 0.5892 - val_loss: 1.9340 - val_accuracy: 0.3267
Epoch 22/50
75/75 [==============================] - 3s 33ms/step - loss: 1.2682 - accuracy: 0.5867 - val_loss: 1.9388 - val_accuracy: 0.3283
Epoch 23/50
75/75 [==============================] - 2s 32ms/step - loss: 1.2270 - accuracy: 0.6108 - val_loss: 1.9592 - val_accuracy: 0.3383
Epoch 24/50
75/75 [==============================] - 2s 32ms/step - loss: 1.2182 - accuracy: 0.6267 - val_loss: 1.9490 - val_accuracy: 0.3333
Epoch 25/50
75/75 [==============================] - 2s 32ms/step - loss: 1.1657 - accuracy: 0.6383 - val_loss: 1.9617 - val_accuracy: 0.3350
Epoch 26/50
75/75 [==============================] - 2s 32ms/step - loss: 1.1350 - accuracy: 0.6575 - val_loss: 1.9531 - val_accuracy: 0.3350
Epoch 27/50
75/75 [==============================] - 2s 32ms/step - loss: 1.1083 - accuracy: 0.6600 - val_loss: 1.9765 - val_accuracy: 0.3300
Epoch 28/50
75/75 [==============================] - 2s 31ms/step - loss: 1.0747 - accuracy: 0.6892 - val_loss: 1.9626 - val_accuracy: 0.3300
Epoch 29/50
75/75 [==============================] - 2s 32ms/step - loss: 1.0402 - accuracy: 0.6892 - val_loss: 1.9616 - val_accuracy: 0.3500
Epoch 30/50
75/75 [==============================] - 2s 32ms/step - loss: 1.0170 - accuracy: 0.7183 - val_loss: 1.9743 - val_accuracy: 0.3383
Epoch 31/50
75/75 [==============================] - 3s 34ms/step - loss: 0.9694 - accuracy: 0.7267 - val_loss: 1.9757 - val_accuracy: 0.3400
Epoch 32/50
75/75 [==============================] - 2s 32ms/step - loss: 0.9398 - accuracy: 0.7425 - val_loss: 2.0110 - val_accuracy: 0.3383
Epoch 33/50
75/75 [==============================] - 2s 33ms/step - loss: 0.8974 - accuracy: 0.7708 - val_loss: 1.9942 - val_accuracy: 0.3367
Epoch 34/50
75/75 [==============================] - 2s 32ms/step - loss: 0.8906 - accuracy: 0.7550 - val_loss: 2.0038 - val_accuracy: 0.3333
Epoch 35/50
75/75 [==============================] - 2s 32ms/step - loss: 0.8513 - accuracy: 0.7658 - val_loss: 1.9889 - val_accuracy: 0.3383
Epoch 36/50
75/75 [==============================] - 2s 32ms/step - loss: 0.8221 - accuracy: 0.7833 - val_loss: 2.0166 - val_accuracy: 0.3450
Epoch 37/50
75/75 [==============================] - 2s 33ms/step - loss: 0.7910 - accuracy: 0.7942 - val_loss: 2.0511 - val_accuracy: 0.3317
Epoch 38/50
75/75 [==============================] - 2s 32ms/step - loss: 0.7646 - accuracy: 0.8092 - val_loss: 2.0073 - val_accuracy: 0.3433
Epoch 39/50
75/75 [==============================] - 2s 31ms/step - loss: 0.7299 - accuracy: 0.8300 - val_loss: 2.0251 - val_accuracy: 0.3450
Epoch 40/50
75/75 [==============================] - 2s 32ms/step - loss: 0.6897 - accuracy: 0.8425 - val_loss: 2.0152 - val_accuracy: 0.3500
Epoch 41/50
75/75 [==============================] - 3s 34ms/step - loss: 0.6626 - accuracy: 0.8483 - val_loss: 2.0153 - val_accuracy: 0.3567
Epoch 42/50
75/75 [==============================] - 2s 32ms/step - loss: 0.6490 - accuracy: 0.8492 - val_loss: 2.0211 - val_accuracy: 0.3567
Epoch 43/50
75/75 [==============================] - 2s 33ms/step - loss: 0.6234 - accuracy: 0.8633 - val_loss: 2.0786 - val_accuracy: 0.3483
Epoch 44/50
75/75 [==============================] - 2s 32ms/step - loss: 0.6003 - accuracy: 0.8775 - val_loss: 2.0339 - val_accuracy: 0.3650
Epoch 45/50
75/75 [==============================] - 2s 31ms/step - loss: 0.5793 - accuracy: 0.8825 - val_loss: 2.0583 - val_accuracy: 0.3550
Epoch 46/50
75/75 [==============================] - 2s 32ms/step - loss: 0.5476 - accuracy: 0.8833 - val_loss: 2.1119 - val_accuracy: 0.3467
Epoch 47/50
75/75 [==============================] - 2s 32ms/step - loss: 0.5330 - accuracy: 0.8967 - val_loss: 2.0789 - val_accuracy: 0.3500
Epoch 48/50
75/75 [==============================] - 2s 32ms/step - loss: 0.5160 - accuracy: 0.8950 - val_loss: 2.0711 - val_accuracy: 0.3567
Epoch 49/50
75/75 [==============================] - 2s 32ms/step - loss: 0.4853 - accuracy: 0.9142 - val_loss: 2.0830 - val_accuracy: 0.3650
Epoch 50/50
75/75 [==============================] - 2s 32ms/step - loss: 0.4650 - accuracy: 0.9242 - val_loss: 2.0991 - val_accuracy: 0.3700
Out[ ]:
<keras.callbacks.History at 0x7fdd6160e0d0>

Save and load the model

In [ ]:
if TRAIN:
    model.save('/content/drive/MyDrive/datas/my_model.h5')
else:
    model = tf.keras.models.load_model('/content/drive/MyDrive/datas/my_model.h5')
model.evaluate(ds_test)
WARNING:tensorflow:Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.
38/38 [==============================] - 1s 17ms/step - loss: 2.0991 - accuracy: 0.3700
Out[ ]:
[2.0991361141204834, 0.3700000047683716]

Model structure

In [ ]:
tf.keras.utils.plot_model(model)
Out[ ]:

Accuracy and Loss plot for custom model

In [ ]:
# summarize history for accuracy
plt.plot(model.history.history['accuracy'])
plt.plot(model.history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# summarize history for loss
plt.plot(model.history.history['loss'])
plt.plot(model.history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

Pretrained VGG

In [44]:
ACCURACY_THRESHOLD = 1
class CustomCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        # print("End epoch {} of training; got log keys: {}".format(epoch, keys))
        if(logs.get('accuracy') == ACCURACY_THRESHOLD):   
            print("\nReached %2.2f%% training accuracy, so stopping training!!" %(ACCURACY_THRESHOLD*100))   
            self.model.stop_training = True

# checkpoint = keras.callbacks.ModelCheckpoint('', monitor='accuracy', mode='max', verbose=1, save_best_only=True)
In [45]:
vgg = tf.keras.applications.VGG16(include_top=False,
                                        weights='imagenet',)
inp = tf.keras.layers.Input(shape = NEW_IMG_SIZE)
layer = tf.keras.layers.MaxPool2D(2)(inp)
layer = vgg(layer)
layer = tf.keras.layers.Flatten()(layer)
layer = tf.keras.layers.Dense(1024, activation=tf.keras.layers.LeakyReLU())(layer)
output = tf.keras.layers.Dense(12, activation = 'softmax')(layer)

my_vgg = tf.keras.models.Model(inputs = inp, outputs = output)
my_vgg.compile(loss = 'categorical_crossentropy', 
                  optimizer = tf.keras.optimizers.Adam(0.00001), 
                  metrics = ['accuracy'])
my_vgg.fit(ds_train, 
               validation_data = ds_test,
               epochs = EPOCHS, 
             callbacks = [CustomCallback()])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
58892288/58889256 [==============================] - 1s 0us/step
58900480/58889256 [==============================] - 1s 0us/step
Epoch 1/50
      6/Unknown - 5s 153ms/step - loss: 2.4138 - accuracy: 0.1667WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0620s vs `on_train_batch_end` time: 0.0761s). Check your callbacks.
75/75 [==============================] - 18s 188ms/step - loss: 1.5419 - accuracy: 0.4750 - val_loss: 0.9790 - val_accuracy: 0.6217
Epoch 2/50
75/75 [==============================] - 13s 169ms/step - loss: 0.5386 - accuracy: 0.8250 - val_loss: 0.7757 - val_accuracy: 0.7283
Epoch 3/50
75/75 [==============================] - 13s 171ms/step - loss: 0.2112 - accuracy: 0.9458 - val_loss: 0.7972 - val_accuracy: 0.7267
Epoch 4/50
75/75 [==============================] - 13s 171ms/step - loss: 0.0890 - accuracy: 0.9750 - val_loss: 0.7918 - val_accuracy: 0.7550
Epoch 5/50
75/75 [==============================] - 13s 169ms/step - loss: 0.0776 - accuracy: 0.9783 - val_loss: 0.9040 - val_accuracy: 0.7483
Epoch 6/50
75/75 [==============================] - 13s 168ms/step - loss: 0.0307 - accuracy: 0.9908 - val_loss: 0.8138 - val_accuracy: 0.7717
Epoch 7/50
75/75 [==============================] - ETA: 0s - loss: 0.0039 - accuracy: 1.0000
Reached 100.00% training accuracy, so stopping training!!
75/75 [==============================] - 13s 169ms/step - loss: 0.0039 - accuracy: 1.0000 - val_loss: 0.8846 - val_accuracy: 0.7700
Out[45]:
<keras.callbacks.History at 0x7f23983f8b90>

Test accuracy & Model weights

In [46]:
if TRAIN:
    my_vgg.save('/content/drive/MyDrive/datas/my_vgg.h5')
else:
    my_vgg = tf.keras.models.load_model('/content/drive/MyDrive/datas/my_vgg.h5')
my_vgg.evaluate(ds_test)
38/38 [==============================] - 2s 49ms/step - loss: 0.8846 - accuracy: 0.7700
Out[46]:
[0.884645402431488, 0.7699999809265137]

Visualization

Results prediction for custom model

In [61]:
matrix = confusion_matrix(test_df.labels, np.argmax(model.predict(ds_test), axis=1))
cm_df = pd.DataFrame(matrix,
                     index = [labels_TO_classes[i] for i in range(0,12)], 
                     columns = [labels_TO_classes[i] for i in range(0,12)])
plt.figure(figsize=(10,6))
sns.heatmap(cm_df, annot=True)
plt.title('Confusion Matrix for Custom Model')
plt.ylabel('Actal Values')
plt.xlabel('Predicted Values')
plt.show()

Results prediction for VGG model

In [62]:
matrix = confusion_matrix(test_df.labels, np.argmax(my_vgg.predict(ds_test), axis=1))
cm_df = pd.DataFrame(matrix,
                     index = [labels_TO_classes[i] for i in range(0,12)], 
                     columns = [labels_TO_classes[i] for i in range(0,12)])
plt.figure(figsize=(10,6))
sns.heatmap(cm_df, annot=True)
plt.title('Confusion Matrix for VGG')
plt.ylabel('Actal Values')
plt.xlabel('Predicted Values')
plt.show()

Similarity between training and test datasets

In [9]:
tsne = TSNE(n_components=2, verbose=0, 
            perplexity=40, n_iter=500, learning_rate=100)
train_tsne_results = tsne.fit_transform(np.stack([img.flatten() for img in train_df.images.values]))
test_tsne_results = tsne.fit_transform(np.stack([img.flatten() for img in test_df.images.values]))
In [26]:
train_df = pd.DataFrame({"images":[],"labels":[]})
test_df = pd.DataFrame({"images":[],"labels":[],"most_similar_to":[]})

for cls in classes:
    images_test = os.listdir(path_test+cls)
    images_train = os.listdir(path_train+cls)

    most_similar_to = [np.nan]*(len(images_test))
    temp_train = pd.DataFrame({"images":images_train,"labels":[cls]*(len(images_train))})
    temp_test = pd.DataFrame({"images":images_test,"labels":[cls]*(len(images_test)), "most_similar_to":most_similar_to})

    test_df = pd.concat([test_df, temp_test])
    train_df = pd.concat([train_df, temp_train])
In [28]:
similar_to = []
for i in range(600):
    cos = cosine_similarity(train_tsne_results, np.expand_dims(test_tsne_results[i], axis=0))
    idx = np.argmax(cos)
    similar_to.append(train_df.images.values[idx])

test_df['most_similar_to'] = similar_to
In [29]:
test_df.head()
Out[29]:
images labels most_similar_to
0 40199_166606_B_5_8.tif skin_2 40134_97237_A_5_2.tif
1 40229_88854_B_6_8.tif skin_2 40044_96798_A_7_4.tif
2 40182_140480_B_5_8.tif skin_2 40071_95769_A_8_8.tif
3 40170_89105_B_5_8.tif skin_2 40066_141607_B_2_3.tif
4 40182_140480_B_6_8.tif skin_2 40118_84737_B_8_4.tif

Model output and internal working understanding

In [91]:
def visualize_layer_PCA(layer_name, data, labels, mod):#, ax):
    layer_output = mod.get_layer(layer_name).output

    intermediate_model = tf.keras.models.Model(inputs = mod.input,
                                               outputs = layer_output)
    
    intermediate_prediction = intermediate_model.predict(data) # tf.cast(data, tf.float32)/255.0

    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(np.stack([img.flatten() for img in intermediate_prediction]))

    plt.subplot(1, 2, 1)
    sns.scatterplot(
        x = pca_result[:,0], 
        y = pca_result[:,1],
        hue=labels,
        palette = sns.color_palette("hls", 12),
        legend = "full",
        alpha = 0.3)#, ax = ax)
    plt.title(f"PCA for layer {layer_name}")
    plt.xlabel("PCA 1st Component")
    plt.ylabel("PCA 2nd Component")
    for i in range(10):
        xtext, ytext = np.median(pca_result[labels == i, :], axis=0)
        plt.text(xtext, ytext, labels_TO_classes[i], fontsize=12)
In [102]:
def visualize_layer_TSNE(layer_name, data, labels, mod,):# ax):
    layer_output = mod.get_layer(layer_name).output

    intermediate_model = tf.keras.models.Model(inputs = mod.input,
                                               outputs = layer_output)
    intermediate_prediction = intermediate_model.predict(data)

    tsne = TSNE(n_components=2, verbose=0, perplexity=20, n_iter=500)
    tsne_results = tsne.fit_transform(np.stack([img.flatten() for img in intermediate_prediction]))

    # plt.figure(figsize=(10,6))
    plt.subplot(1, 2, 2)
    sns.scatterplot(
        x = tsne_results[:,0], 
        y = tsne_results[:,1],
        hue=labels,
        palette=sns.color_palette("hls", 12),
        legend="full",
        alpha=0.3,)
        # ax = ax)
    plt.title(f"TSNE for layer {layer_name}")
    plt.xlabel("TSNE 1st Component")
    plt.ylabel("TSNE 2nd Component")
    for i in range(10):
        xtext, ytext = np.median(tsne_results[labels == i, :], axis=0)
        plt.text(xtext, ytext, labels_TO_classes[i], fontsize=12)
    # plt.show()

TSNE on train set and output for train set images

In [65]:
tsne_inp = train_df.images.values

tsne = TSNE(n_components=2, verbose=1, 
            perplexity=40, n_iter=500)
tsne_results = tsne.fit_transform(np.stack([img.flatten() for img in tsne_inp]))

plt.figure(figsize=(12,8))
sns.scatterplot(
            x = tsne_results[:,0], 
            y = tsne_results[:,1],
            hue=train_df['labels'],
            palette=sns.color_palette("hls", 12),
            legend="full",
            alpha=0.3
        )
plt.title("TSNE for Input as images itself")
plt.xlabel("TSNE 1st Component")
plt.ylabel("TSNE 2nd Component")

for idx in range(12):
    xtext, ytext = np.median(tsne_results[train_df['labels'] == idx, :], axis=0)
    plt.text(xtext, ytext, labels_TO_classes[idx], fontsize=12)

plt.show()
/usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:783: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.
  FutureWarning,
/usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:793: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.
  FutureWarning,
[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 1200 samples in 1.255s...
[t-SNE] Computed neighbors for 1200 samples in 70.553s...
[t-SNE] Computed conditional probabilities for sample 1000 / 1200
[t-SNE] Computed conditional probabilities for sample 1200 / 1200
[t-SNE] Mean sigma: 8817.663640
[t-SNE] KL divergence after 250 iterations with early exaggeration: 105.686523
[t-SNE] KL divergence after 500 iterations: 1.562698
In [103]:
for layer in model.layers:
    name = layer.name
    if 'Conv2D' in name: 
        print("PCA & T-SNE for layer : ", name)
        plt.figure(figsize=(20,10))
        visualize_layer_PCA(name, ds_train, train_df.labels.values, model)
        visualize_layer_TSNE(name, ds_train, train_df.labels.values, model)
        plt.show()
PCA & T-SNE for layer :  Conv2D_Layer_1
PCA & T-SNE for layer :  Conv2D_Layer_2
PCA & T-SNE for layer :  Conv2D_Layer_3
PCA & T-SNE for layer :  Conv2D_Layer_4
PCA & T-SNE for layer :  Conv2D_Layer_5

TSNE on test set and output for test set images

In [73]:
tsne_inp = test_df.images.values

tsne = TSNE(n_components=2, verbose=1, 
            perplexity=40, n_iter=500, learning_rate=100)
tsne_results = tsne.fit_transform(np.stack([img.flatten() for img in tsne_inp]))

plt.figure(figsize=(12,8))
sns.scatterplot(
            x = tsne_results[:,0], 
            y = tsne_results[:,1],
            hue=test_df['labels'],
            palette=sns.color_palette("hls", 12),
            legend="full",
            alpha=0.3
        )
plt.title("TSNE for Input as images itself")
plt.xlabel("TSNE 1st Component")
plt.ylabel("TSNE 2nd Component")

for idx in range(12):
    xtext, ytext = np.median(tsne_results[test_df['labels'] == idx, :], axis=0)
    plt.text(xtext, ytext, labels_TO_classes[idx], fontsize=12)

plt.show()
[t-SNE] Computing 121 nearest neighbors...
[t-SNE] Indexed 600 samples in 0.486s...
[t-SNE] Computed neighbors for 600 samples in 23.169s...
[t-SNE] Computed conditional probabilities for sample 600 / 600
[t-SNE] Mean sigma: 10352.059849
[t-SNE] KL divergence after 250 iterations with early exaggeration: 90.249748
[t-SNE] KL divergence after 500 iterations: 1.164943
In [94]:
for layer in model.layers:
    name = layer.name
    if 'Conv2D' in name: 
        print("PCA & T-SNE for layer : ", name)
        plt.figure(figsize=(20,10))
        visualize_layer_PCA(name, ds_test, test_df.labels.values, model)
        visualize_layer_TSNE(name, ds_test, test_df.labels.values, model)
        plt.show()
PCA & T-SNE for layer :  Conv2D_Layer_1
PCA & T-SNE for layer :  Conv2D_Layer_2
PCA & T-SNE for layer :  Conv2D_Layer_3
PCA & T-SNE for layer :  Conv2D_Layer_4
PCA & T-SNE for layer :  Conv2D_Layer_5

Segmentation

Some points:-

  • The image size was too large that the custom model that i built went out of memory to allocate the weights of the model for even a single image - so had to reduce the size of the imgaes a lot
  • Due to some pecularities i wasnt able to get any output as masks so did some changes to the input(cheating) like i added the annotated masks to the images to make the features which are to be masked stand out but even that didnt make any difference.

So some reasons why this happened can be:-

  • Noise in the data which seems to be there (as can be seen from the black patches in the data)
  • Less data - definitely the data was less but i tries overcoming this with data augmentation(didnt make any difference)
  • The features to be masked are the least prominent in the image
In [ ]:
encodings_train = pd.read_csv("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/train.csv")
encodings_train = encodings_train[encodings_train.id != "HandE_B005_CL_b_RGB_topright"]
encodings_test = pd.read_csv("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/test.csv")
train_df = np.zeros(((5, 4536, 4704, 3)))
test_df = np.zeros(((2, 4536, 4704, 3)))

for idx,i in enumerate(encodings_train['id']):
    train_df[idx,:] = np.array(Image.open("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/train/"+i+".tiff"))

for idx,i in enumerate(encodings_test['id']):
    test_df[idx,:] = np.array(Image.open("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/test/"+i+".tiff"))

def rle2mask(mask_rle, shape): 
    '''
    mask_rle: run-length as string format (start length)
    shape: (width,height) of array to return 
    Returns numpy array, 1 - mask, 0 - background

    '''
    s = mask_rle
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T

train_mask = np.zeros((5, 4536, 4704) )
test_mask = np.zeros((2, 4536, 4704) )

for idx,x in enumerate(encodings_train['predicted']):
    single_image_rle = x.split()
    train_mask[idx,:] = rle2mask(single_image_rle, (4704, 4536))

for idx,x in enumerate(encodings_test['predicted']):
    single_image_rle = x.split()
    test_mask[idx,:] = rle2mask(single_image_rle, (4704, 4536))

def UNet(n_classes=1, IMG_HEIGHT=4536, IMG_WIDTH=4704, IMG_CHANNELS=3):
    inputs = tf.keras.layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    s = inputs

    # Encoder
    c1 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
    c1 = tf.keras.layers.Dropout(0.2)(c1)  
    c1 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
    p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)
    
    c2 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
    c2 = tf.keras.layers.Dropout(0.2)(c2)  
    c2 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
    p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)
     
    c3 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
    c3 = tf.keras.layers.Dropout(0.2)(c3)
    c3 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
    p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)
     
    c4 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
    c4 = tf.keras.layers.Dropout(0.2)(c4)
    c4 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
    p4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c4)
     
    # Transfer Block
    c5 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
    c5 = tf.keras.layers.Dropout(0.3)(c5)
    c5 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
    
    # Decoder 
    u6 = tf.keras.layers.Conv2DTranspose(32, (3, 2), strides=(2, 2))(c5)
    u6 = tf.keras.layers.concatenate([u6, c4])
    c6 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
    c6 = tf.keras.layers.Dropout(0.2)(c6)
    c6 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)
     
    u7 = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = tf.keras.layers.concatenate([u7, c3])
    c7 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
    c7 = tf.keras.layers.Dropout(0.2)(c7)
    c7 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)
     
    u8 = tf.keras.layers.Conv2DTranspose(8, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = tf.keras.layers.concatenate([u8, c2])
    c8 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
    c8 = tf.keras.layers.Dropout(0.2)(c8)  
    c8 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)
     
    u9 = tf.keras.layers.Conv2DTranspose(4, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = tf.keras.layers.concatenate([u9, c1], axis=3)
    c9 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
    c9 = tf.keras.layers.Dropout(0.2)(c9)  
    c9 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)
     
    outputs = tf.keras.layers.Conv2D(n_classes, (1, 1), activation='softmax')(c9)
     
    model = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
    return model

UNet = UNet()
UNet.compile(optimizer='adam',
              loss='mse',
              metrics=['accuracy'])

UNet.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 4536, 4704,  0           []                               
                                 3)]                                                              
                                                                                                  
 conv2d (Conv2D)                (None, 4536, 4704,   112         ['input_1[0][0]']                
                                4)                                                                
                                                                                                  
 dropout (Dropout)              (None, 4536, 4704,   0           ['conv2d[0][0]']                 
                                4)                                                                
                                                                                                  
 conv2d_1 (Conv2D)              (None, 4536, 4704,   148         ['dropout[0][0]']                
                                4)                                                                
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 2268, 2352,   0           ['conv2d_1[0][0]']               
                                4)                                                                
                                                                                                  
 conv2d_2 (Conv2D)              (None, 2268, 2352,   296         ['max_pooling2d[0][0]']          
                                8)                                                                
                                                                                                  
 dropout_1 (Dropout)            (None, 2268, 2352,   0           ['conv2d_2[0][0]']               
                                8)                                                                
                                                                                                  
 conv2d_3 (Conv2D)              (None, 2268, 2352,   584         ['dropout_1[0][0]']              
                                8)                                                                
                                                                                                  
 max_pooling2d_1 (MaxPooling2D)  (None, 1134, 1176,   0          ['conv2d_3[0][0]']               
                                8)                                                                
                                                                                                  
 conv2d_4 (Conv2D)              (None, 1134, 1176,   1168        ['max_pooling2d_1[0][0]']        
                                16)                                                               
                                                                                                  
 dropout_2 (Dropout)            (None, 1134, 1176,   0           ['conv2d_4[0][0]']               
                                16)                                                               
                                                                                                  
 conv2d_5 (Conv2D)              (None, 1134, 1176,   2320        ['dropout_2[0][0]']              
                                16)                                                               
                                                                                                  
 max_pooling2d_2 (MaxPooling2D)  (None, 567, 588, 16  0          ['conv2d_5[0][0]']               
                                )                                                                 
                                                                                                  
 conv2d_6 (Conv2D)              (None, 567, 588, 32  4640        ['max_pooling2d_2[0][0]']        
                                )                                                                 
                                                                                                  
 dropout_3 (Dropout)            (None, 567, 588, 32  0           ['conv2d_6[0][0]']               
                                )                                                                 
                                                                                                  
 conv2d_7 (Conv2D)              (None, 567, 588, 32  9248        ['dropout_3[0][0]']              
                                )                                                                 
                                                                                                  
 max_pooling2d_3 (MaxPooling2D)  (None, 283, 294, 32  0          ['conv2d_7[0][0]']               
                                )                                                                 
                                                                                                  
 conv2d_8 (Conv2D)              (None, 283, 294, 64  18496       ['max_pooling2d_3[0][0]']        
                                )                                                                 
                                                                                                  
 dropout_4 (Dropout)            (None, 283, 294, 64  0           ['conv2d_8[0][0]']               
                                )                                                                 
                                                                                                  
 conv2d_9 (Conv2D)              (None, 283, 294, 64  36928       ['dropout_4[0][0]']              
                                )                                                                 
                                                                                                  
 conv2d_transpose (Conv2DTransp  (None, 567, 588, 32  12320      ['conv2d_9[0][0]']               
 ose)                           )                                                                 
                                                                                                  
 concatenate (Concatenate)      (None, 567, 588, 64  0           ['conv2d_transpose[0][0]',       
                                )                                 'conv2d_7[0][0]']               
                                                                                                  
 conv2d_10 (Conv2D)             (None, 567, 588, 32  18464       ['concatenate[0][0]']            
                                )                                                                 
                                                                                                  
 dropout_5 (Dropout)            (None, 567, 588, 32  0           ['conv2d_10[0][0]']              
                                )                                                                 
                                                                                                  
 conv2d_11 (Conv2D)             (None, 567, 588, 32  9248        ['dropout_5[0][0]']              
                                )                                                                 
                                                                                                  
 conv2d_transpose_1 (Conv2DTran  (None, 1134, 1176,   2064       ['conv2d_11[0][0]']              
 spose)                         16)                                                               
                                                                                                  
 concatenate_1 (Concatenate)    (None, 1134, 1176,   0           ['conv2d_transpose_1[0][0]',     
                                32)                               'conv2d_5[0][0]']               
                                                                                                  
 conv2d_12 (Conv2D)             (None, 1134, 1176,   4624        ['concatenate_1[0][0]']          
                                16)                                                               
                                                                                                  
 dropout_6 (Dropout)            (None, 1134, 1176,   0           ['conv2d_12[0][0]']              
                                16)                                                               
                                                                                                  
 conv2d_13 (Conv2D)             (None, 1134, 1176,   2320        ['dropout_6[0][0]']              
                                16)                                                               
                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 2268, 2352,   520        ['conv2d_13[0][0]']              
 spose)                         8)                                                                
                                                                                                  
 concatenate_2 (Concatenate)    (None, 2268, 2352,   0           ['conv2d_transpose_2[0][0]',     
                                16)                               'conv2d_3[0][0]']               
                                                                                                  
 conv2d_14 (Conv2D)             (None, 2268, 2352,   1160        ['concatenate_2[0][0]']          
                                8)                                                                
                                                                                                  
 dropout_7 (Dropout)            (None, 2268, 2352,   0           ['conv2d_14[0][0]']              
                                8)                                                                
                                                                                                  
 conv2d_15 (Conv2D)             (None, 2268, 2352,   584         ['dropout_7[0][0]']              
                                8)                                                                
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, 4536, 4704,   132        ['conv2d_15[0][0]']              
 spose)                         4)                                                                
                                                                                                  
 concatenate_3 (Concatenate)    (None, 4536, 4704,   0           ['conv2d_transpose_3[0][0]',     
                                8)                                'conv2d_1[0][0]']               
                                                                                                  
 conv2d_16 (Conv2D)             (None, 4536, 4704,   292         ['concatenate_3[0][0]']          
                                4)                                                                
                                                                                                  
 dropout_8 (Dropout)            (None, 4536, 4704,   0           ['conv2d_16[0][0]']              
                                4)                                                                
                                                                                                  
 conv2d_17 (Conv2D)             (None, 4536, 4704,   148         ['dropout_8[0][0]']              
                                4)                                                                
                                                                                                  
 conv2d_18 (Conv2D)             (None, 4536, 4704,   5           ['conv2d_17[0][0]']              
                                1)                                                                
                                                                                                  
==================================================================================================
Total params: 125,821
Trainable params: 125,821
Non-trainable params: 0
__________________________________________________________________________________________________
In [ ]:
UNet.fit(train_df, train_mask, validation_data=(test_df, test_mask), epochs=10)
Epoch 1/10
---------------------------------------------------------------------------
ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input-6-f56e3a95f9e5> in <module>()
----> 1 multi_unet_model.fit(train_df, train_mask, validation_data=(test_df, test_mask), epochs=10)

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57     ctx.ensure_initialized()
     58     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 59                                         inputs, attrs, num_outputs)
     60   except core._NotOkStatusException as e:
     61     if name is not None:

ResourceExhaustedError:  failed to allocate memory
	 [[node gradient_tape/model/dropout_2/dropout/Mul
 (defined at /usr/local/lib/python3.7/dist-packages/keras/optimizer_v2/optimizer_v2.py:464)
]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_3154]

Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/model/dropout_2/dropout/Mul:
In[0] gradient_tape/model/conv2d_5/Conv2D/Conv2DBackpropInput:	
In[1] model/dropout_2/dropout/Cast (defined at /usr/local/lib/python3.7/dist-packages/keras/layers/core/dropout.py:109)

Operation defined at: (most recent call last)
>>>   File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
>>>     "__main__", mod_spec)
>>> 
>>>   File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
>>>     exec(code, run_globals)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
>>>     app.launch_new_instance()
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
>>>     app.start()
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
>>>     self.io_loop.start()
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
>>>     self.asyncio_loop.run_forever()
>>> 
>>>   File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
>>>     self._run_once()
>>> 
>>>   File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
>>>     handle._run()
>>> 
>>>   File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
>>>     self._context.run(self._callback, *self._args)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
>>>     handler_func(fileobj, events)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 452, in _handle_events
>>>     self._handle_recv()
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 481, in _handle_recv
>>>     self._run_callback(callback, msg)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 431, in _run_callback
>>>     callback(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
>>>     return self.dispatch_shell(stream, msg)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
>>>     handler(stream, idents, msg)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
>>>     user_expressions, allow_stdin)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
>>>     res = shell.run_cell(code, store_history=store_history, silent=silent)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
>>>     return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
>>>     interactivity=interactivity, compiler=compiler, result=result)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2828, in run_ast_nodes
>>>     if self.run_code(code, result):
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
>>>     exec(code_obj, self.user_global_ns, self.user_ns)
>>> 
>>>   File "<ipython-input-6-f56e3a95f9e5>", line 1, in <module>
>>>     multi_unet_model.fit(train_df, train_mask, validation_data=(test_df, test_mask), epochs=10)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
>>>     return fn(*args, **kwargs)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1216, in fit
>>>     tmp_logs = self.train_function(iterator)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 878, in train_function
>>>     return step_function(self, iterator)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 867, in step_function
>>>     outputs = model.distribute_strategy.run(run_step, args=(data,))
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in run_step
>>>     outputs = model.train_step(data)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 816, in train_step
>>>     self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/optimizer_v2/optimizer_v2.py", line 531, in minimize
>>>     loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/optimizer_v2/optimizer_v2.py", line 583, in _compute_gradients
>>>     grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
>>> 
>>>   File "/usr/local/lib/python3.7/dist-packages/keras/optimizer_v2/optimizer_v2.py", line 464, in _get_gradients
>>>     grads = tape.gradient(loss, var_list, grad_loss)
>>> 
In [35]:
encodings_train2 = pd.read_csv("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/train.csv")
encodings_train2 = encodings_train2[encodings_train2.id != "HandE_B005_CL_b_RGB_topright"]
encodings_test2 = pd.read_csv("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/test.csv")
train_df2 = np.zeros(((5, 512, 512, 3)), dtype = np.float32)
test_df2 = np.zeros(((2, 512, 512, 3)), dtype = np.float32)

for idx,i in enumerate(encodings_train2['id']):
    json_filename = "/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/train/"+i+'.json'
    read_file = open(json_filename, "r") 
    data = json.load(read_file)

    polys = []
    for index in range(data.__len__()):
        geom = np.array(data[index]['geometry']['coordinates'])
        polys.append(geom.astype(int))

    img = Image.open("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/train/"+i+".tiff")
    for i in range(len(polys)):
        poly = polys[i]
        ImageDraw.Draw(img).polygon(tuple(map(tuple, poly[0])), outline=1, fill=1) 

    train_df2[idx,:] = cv2.resize(np.array(img), (512, 512)).astype(np.float32)/255.0

for idx,i in enumerate(encodings_test2['id']):
    json_filename = "/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/test/"+i+'.json'
    read_file = open(json_filename, "r") 
    data = json.load(read_file)

    polys = []
    for index in range(data.__len__()):
        geom = np.array(data[index]['geometry']['coordinates'])
        polys.append(geom.astype(int))

    img = Image.open("/content/drive/MyDrive/hiring_assignment/Colonic_crypt_dataset/test/"+i+".tiff")
    for i in range(len(polys)):
        poly = polys[i]
        ImageDraw.Draw(img).polygon(tuple(map(tuple, poly[0])), outline=1, fill=1) 

    test_df2[idx,:] = cv2.resize(np.array(img), (512, 512)).astype(np.float32)/255.0
In [42]:
def rle2mask(mask_rle, shape): 
    '''
    mask_rle: run-length as string format (start length)
    shape: (width,height) of array to return 
    Returns numpy array, 1 - mask, 0 - background

    '''
    s = mask_rle
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return cv2.resize(img.reshape(shape).T, (512, 512))

train_mask2 = np.zeros((5, 512, 512) )
test_mask2 = np.zeros((2, 512, 512) )

for idx,x in enumerate(encodings_train2['predicted']):
    single_image_rle = x.split()
    train_mask2[idx,:] = rle2mask(single_image_rle, (4704, 4536))

for idx,x in enumerate(encodings_test2['predicted']):
    single_image_rle = x.split()
    test_mask2[idx,:] = rle2mask(single_image_rle, (4704, 4536))
In [37]:
plt.figure(figsize=(20,20))
for idx, i in enumerate(train_df2):
    plt.subplot(1, 5, idx+1)
    plt.imshow(i)
    plt.axis("off")
In [ ]:
plt.figure(figsize=(20,20))
for idx, i in enumerate(train_df2):
    plt.subplot(1, 5, idx+1)
    plt.imshow(i)
    plt.axis("off")
In [38]:
!pip install git+https://github.com/tensorflow/examples.git
from tensorflow_examples.models.pix2pix import pix2pix
Collecting git+https://github.com/tensorflow/examples.git
  Cloning https://github.com/tensorflow/examples.git to /tmp/pip-req-build-vtae2qrb
  Running command git clone -q https://github.com/tensorflow/examples.git /tmp/pip-req-build-vtae2qrb
Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from tensorflow-examples===6a5c4df82b032e1ee1e5095e6f9baeb732b294db-) (0.12.0)
Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from tensorflow-examples===6a5c4df82b032e1ee1e5095e6f9baeb732b294db-) (1.15.0)
Building wheels for collected packages: tensorflow-examples
  Building wheel for tensorflow-examples (setup.py) ... done
  Created wheel for tensorflow-examples: filename=tensorflow_examples-6a5c4df82b032e1ee1e5095e6f9baeb732b294db_-py3-none-any.whl size=268414 sha256=f732e149c40e736f247305dc7a4a3ae4d8e828b3e2306b5fd829ecfec83cf525
  Stored in directory: /tmp/pip-ephem-wheel-cache-egnqjh04/wheels/eb/19/50/2a4363c831fa12b400af86325a6f26ade5d2cdc5b406d552ca
  WARNING: Built wheel for tensorflow-examples is invalid: Metadata 1.2 mandates PEP 440 version, but '6a5c4df82b032e1ee1e5095e6f9baeb732b294db-' is not
Failed to build tensorflow-examples
Installing collected packages: tensorflow-examples
    Running setup.py install for tensorflow-examples ... done
  DEPRECATION: tensorflow-examples was installed using the legacy 'setup.py install' method, because a wheel could not be built for it. A possible replacement is to fix the wheel build issue reported above. You can find discussion regarding this at https://github.com/pypa/pip/issues/8368.
Successfully installed tensorflow-examples-6a5c4df82b032e1ee1e5095e6f9baeb732b294db-
In [39]:
from keras import backend as K
def dice_coef(y_true, y_pred, smooth=1):
    a = K.flatten(y_true)
    b = K.flatten(y_pred)
    intersection = K.sum(a * b)
    return (2. * intersection + smooth)/(K.sum(a) + K.sum(b) + smooth)
In [58]:
base_model = tf.keras.applications.MobileNetV2(input_shape=[512, 512, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model(output_channels:int):
  inputs = tf.keras.layers.Input(shape=[512, 512, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,activation='softmax',
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

model = unet_model(output_channels=1)
model.compile(optimizer=tf.keras.optimizers.Adam(0.0000000001),
              loss=dice_coef,
              metrics=['accuracy'])

model.summary()
WARNING:tensorflow:`input_shape` is undefined or non-square, or `rows` is not in [96, 128, 160, 192, 224]. Weights for input shape (224, 224) will be loaded as the default.
Model: "model_8"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_9 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 model_7 (Functional)           [(None, 256, 256, 9  1841984     ['input_9[0][0]']                
                                6),                                                               
                                 (None, 128, 128, 1                                               
                                44),                                                              
                                 (None, 64, 64, 192                                               
                                ),                                                                
                                 (None, 32, 32, 576                                               
                                ),                                                                
                                 (None, 16, 16, 320                                               
                                )]                                                                
                                                                                                  
 sequential_8 (Sequential)      (None, 32, 32, 512)  1476608     ['model_7[0][4]']                
                                                                                                  
 concatenate_20 (Concatenate)   (None, 32, 32, 1088  0           ['sequential_8[0][0]',           
                                )                                 'model_7[0][3]']                
                                                                                                  
 sequential_9 (Sequential)      (None, 64, 64, 256)  2507776     ['concatenate_20[0][0]']         
                                                                                                  
 concatenate_21 (Concatenate)   (None, 64, 64, 448)  0           ['sequential_9[0][0]',           
                                                                  'model_7[0][2]']                
                                                                                                  
 sequential_10 (Sequential)     (None, 128, 128, 12  516608      ['concatenate_21[0][0]']         
                                8)                                                                
                                                                                                  
 concatenate_22 (Concatenate)   (None, 128, 128, 27  0           ['sequential_10[0][0]',          
                                2)                                'model_7[0][1]']                
                                                                                                  
 sequential_11 (Sequential)     (None, 256, 256, 64  156928      ['concatenate_22[0][0]']         
                                )                                                                 
                                                                                                  
 concatenate_23 (Concatenate)   (None, 256, 256, 16  0           ['sequential_11[0][0]',          
                                0)                                'model_7[0][0]']                
                                                                                                  
 conv2d_transpose_26 (Conv2DTra  (None, 512, 512, 1)  1441       ['concatenate_23[0][0]']         
 nspose)                                                                                          
                                                                                                  
==================================================================================================
Total params: 6,501,345
Trainable params: 4,657,441
Non-trainable params: 1,843,904
__________________________________________________________________________________________________
In [59]:
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        rotation_range=20,vertical_flip=True)

train_data = train_datagen.flow(train_df2, train_mask2, batch_size=2)
test_data = train_datagen.flow(test_df2, test_mask2, batch_size=2)

model.fit(train_data, validation_data=test_data, epochs=50)
Epoch 1/50
3/3 [==============================] - 5s 748ms/step - loss: 0.1000 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 2/50
3/3 [==============================] - 1s 177ms/step - loss: 0.1000 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 3/50
3/3 [==============================] - 1s 211ms/step - loss: 0.1002 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 4/50
3/3 [==============================] - 1s 180ms/step - loss: 0.1003 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 5/50
3/3 [==============================] - 1s 210ms/step - loss: 0.1000 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 6/50
3/3 [==============================] - 1s 171ms/step - loss: 0.1008 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 7/50
3/3 [==============================] - 1s 175ms/step - loss: 0.1000 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 8/50
3/3 [==============================] - 1s 177ms/step - loss: 0.1002 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 9/50
3/3 [==============================] - 0s 171ms/step - loss: 0.1000 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 10/50
3/3 [==============================] - 1s 216ms/step - loss: 0.1005 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 11/50
3/3 [==============================] - 0s 171ms/step - loss: 0.1006 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 12/50
3/3 [==============================] - 1s 204ms/step - loss: 0.1004 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 13/50
3/3 [==============================] - 1s 172ms/step - loss: 0.1010 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 14/50
3/3 [==============================] - 1s 168ms/step - loss: 0.1009 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 15/50
3/3 [==============================] - 1s 181ms/step - loss: 0.1009 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 16/50
3/3 [==============================] - 1s 180ms/step - loss: 0.1001 - accuracy: 0.0532 - val_loss: 0.1401 - val_accuracy: 0.0753
Epoch 17/50
3/3 [==============================] - ETA: 0s - loss: 0.1007 - accuracy: 0.0532
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-59-1dff1039f736> in <module>()
      8 test_data = train_datagen.flow(test_df2, test_mask2, batch_size=2)
      9 
---> 10 model.fit(train_data, validation_data=test_data, epochs=50)

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     62     filtered_tb = None
     63     try:
---> 64       return fn(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1261               use_multiprocessing=use_multiprocessing,
   1262               return_dict=True,
-> 1263               _use_cached_eval_dataset=True)
   1264           val_logs = {'val_' + name: val for name, val in val_logs.items()}
   1265           epoch_logs.update(val_logs)

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     62     filtered_tb = None
     63     try:
---> 64       return fn(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in evaluate(self, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, return_dict, **kwargs)
   1535             with tf.profiler.experimental.Trace('test', step_num=step, _r=1):
   1536               callbacks.on_test_batch_begin(step)
-> 1537               tmp_logs = self.test_function(iterator)
   1538               if data_handler.should_sync:
   1539                 context.async_wait()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    148     filtered_tb = None
    149     try:
--> 150       return fn(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    908 
    909       with OptionalXlaContext(self._jit_compile):
--> 910         result = self._call(*args, **kwds)
    911 
    912       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    947       # In this case we have not created variables on the first call. So we can
    948       # run the first trace but we should fail if variables are created.
--> 949       results = self._stateful_fn(*args, **kwds)
    950       if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
    951         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   3129        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   3130     return graph_function._call_flat(
-> 3131         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   3132 
   3133   @property

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1958       # No tape is watching; skip to running the function.
   1959       return self._build_call_outputs(self._inference_function.call(
-> 1960           ctx, args, cancellation_manager=cancellation_manager))
   1961     forward_backward = self._select_forward_and_backward_functions(
   1962         args,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    601               inputs=args,
    602               attrs=attrs,
--> 603               ctx=ctx)
    604         else:
    605           outputs = execute.execute_with_cancellation(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57     ctx.ensure_initialized()
     58     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 59                                         inputs, attrs, num_outputs)
     60   except core._NotOkStatusException as e:
     61     if name is not None:

KeyboardInterrupt: 
In [60]:
plt.imshow(np.squeeze(model.predict(tf.expand_dims(train_df2[0], 0))))
Out[60]:
<matplotlib.image.AxesImage at 0x7f40cc183890>
In [51]:
def UNet2(n_classes=1, IMG_HEIGHT=4536, IMG_WIDTH=4704, IMG_CHANNELS=3):
    inputs = tf.keras.layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    s = inputs

    # Encoder
    c1 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(s)
    c1 = tf.keras.layers.Dropout(0.2)(c1)  
    c1 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
    p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)
    
    c2 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
    c2 = tf.keras.layers.Dropout(0.2)(c2)  
    c2 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
    p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)
     
    c3 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
    c3 = tf.keras.layers.Dropout(0.2)(c3)
    c3 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
    p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)
     
    c4 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
    c4 = tf.keras.layers.Dropout(0.2)(c4)
    c4 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
    p4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c4)
     
    # Transfer Block
    c5 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
    c5 = tf.keras.layers.Dropout(0.3)(c5)
    c5 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
    
    # Decoder
    u6 = tf.keras.layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = tf.keras.layers.concatenate([u6, c4])
    c6 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
    c6 = tf.keras.layers.Dropout(0.2)(c6)
    c6 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)
     
    u7 = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = tf.keras.layers.concatenate([u7, c3])
    c7 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
    c7 = tf.keras.layers.Dropout(0.2)(c7)
    c7 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)
     
    u8 = tf.keras.layers.Conv2DTranspose(8, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = tf.keras.layers.concatenate([u8, c2])
    c8 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
    c8 = tf.keras.layers.Dropout(0.2)(c8)  
    c8 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)
     
    u9 = tf.keras.layers.Conv2DTranspose(4, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = tf.keras.layers.concatenate([u9, c1], axis=3)
    c9 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
    c9 = tf.keras.layers.Dropout(0.2)(c9)  
    c9 = tf.keras.layers.Conv2D(4, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)
     
    outputs = tf.squeeze(tf.keras.layers.Conv2D(n_classes, (1, 1), activation='softmax')(c9))
     
    model = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
    return model

model = UNet2(1, 512, 512)
model.compile(optimizer=tf.keras.optimizers.Adam(0.00001),
              loss=dice_coef,)

model.summary()
Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_7 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_38 (Conv2D)             (None, 512, 512, 4)  112         ['input_7[0][0]']                
                                                                                                  
 dropout_18 (Dropout)           (None, 512, 512, 4)  0           ['conv2d_38[0][0]']              
                                                                                                  
 conv2d_39 (Conv2D)             (None, 512, 512, 4)  148         ['dropout_18[0][0]']             
                                                                                                  
 max_pooling2d_8 (MaxPooling2D)  (None, 256, 256, 4)  0          ['conv2d_39[0][0]']              
                                                                                                  
 conv2d_40 (Conv2D)             (None, 256, 256, 8)  296         ['max_pooling2d_8[0][0]']        
                                                                                                  
 dropout_19 (Dropout)           (None, 256, 256, 8)  0           ['conv2d_40[0][0]']              
                                                                                                  
 conv2d_41 (Conv2D)             (None, 256, 256, 8)  584         ['dropout_19[0][0]']             
                                                                                                  
 max_pooling2d_9 (MaxPooling2D)  (None, 128, 128, 8)  0          ['conv2d_41[0][0]']              
                                                                                                  
 conv2d_42 (Conv2D)             (None, 128, 128, 16  1168        ['max_pooling2d_9[0][0]']        
                                )                                                                 
                                                                                                  
 dropout_20 (Dropout)           (None, 128, 128, 16  0           ['conv2d_42[0][0]']              
                                )                                                                 
                                                                                                  
 conv2d_43 (Conv2D)             (None, 128, 128, 16  2320        ['dropout_20[0][0]']             
                                )                                                                 
                                                                                                  
 max_pooling2d_10 (MaxPooling2D  (None, 64, 64, 16)  0           ['conv2d_43[0][0]']              
 )                                                                                                
                                                                                                  
 conv2d_44 (Conv2D)             (None, 64, 64, 32)   4640        ['max_pooling2d_10[0][0]']       
                                                                                                  
 dropout_21 (Dropout)           (None, 64, 64, 32)   0           ['conv2d_44[0][0]']              
                                                                                                  
 conv2d_45 (Conv2D)             (None, 64, 64, 32)   9248        ['dropout_21[0][0]']             
                                                                                                  
 max_pooling2d_11 (MaxPooling2D  (None, 32, 32, 32)  0           ['conv2d_45[0][0]']              
 )                                                                                                
                                                                                                  
 conv2d_46 (Conv2D)             (None, 32, 32, 64)   18496       ['max_pooling2d_11[0][0]']       
                                                                                                  
 dropout_22 (Dropout)           (None, 32, 32, 64)   0           ['conv2d_46[0][0]']              
                                                                                                  
 conv2d_47 (Conv2D)             (None, 32, 32, 64)   36928       ['dropout_22[0][0]']             
                                                                                                  
 conv2d_transpose_18 (Conv2DTra  (None, 64, 64, 32)  8224        ['conv2d_47[0][0]']              
 nspose)                                                                                          
                                                                                                  
 concatenate_16 (Concatenate)   (None, 64, 64, 64)   0           ['conv2d_transpose_18[0][0]',    
                                                                  'conv2d_45[0][0]']              
                                                                                                  
 conv2d_48 (Conv2D)             (None, 64, 64, 32)   18464       ['concatenate_16[0][0]']         
                                                                                                  
 dropout_23 (Dropout)           (None, 64, 64, 32)   0           ['conv2d_48[0][0]']              
                                                                                                  
 conv2d_49 (Conv2D)             (None, 64, 64, 32)   9248        ['dropout_23[0][0]']             
                                                                                                  
 conv2d_transpose_19 (Conv2DTra  (None, 128, 128, 16  2064       ['conv2d_49[0][0]']              
 nspose)                        )                                                                 
                                                                                                  
 concatenate_17 (Concatenate)   (None, 128, 128, 32  0           ['conv2d_transpose_19[0][0]',    
                                )                                 'conv2d_43[0][0]']              
                                                                                                  
 conv2d_50 (Conv2D)             (None, 128, 128, 16  4624        ['concatenate_17[0][0]']         
                                )                                                                 
                                                                                                  
 dropout_24 (Dropout)           (None, 128, 128, 16  0           ['conv2d_50[0][0]']              
                                )                                                                 
                                                                                                  
 conv2d_51 (Conv2D)             (None, 128, 128, 16  2320        ['dropout_24[0][0]']             
                                )                                                                 
                                                                                                  
 conv2d_transpose_20 (Conv2DTra  (None, 256, 256, 8)  520        ['conv2d_51[0][0]']              
 nspose)                                                                                          
                                                                                                  
 concatenate_18 (Concatenate)   (None, 256, 256, 16  0           ['conv2d_transpose_20[0][0]',    
                                )                                 'conv2d_41[0][0]']              
                                                                                                  
 conv2d_52 (Conv2D)             (None, 256, 256, 8)  1160        ['concatenate_18[0][0]']         
                                                                                                  
 dropout_25 (Dropout)           (None, 256, 256, 8)  0           ['conv2d_52[0][0]']              
                                                                                                  
 conv2d_53 (Conv2D)             (None, 256, 256, 8)  584         ['dropout_25[0][0]']             
                                                                                                  
 conv2d_transpose_21 (Conv2DTra  (None, 512, 512, 4)  132        ['conv2d_53[0][0]']              
 nspose)                                                                                          
                                                                                                  
 concatenate_19 (Concatenate)   (None, 512, 512, 8)  0           ['conv2d_transpose_21[0][0]',    
                                                                  'conv2d_39[0][0]']              
                                                                                                  
 conv2d_54 (Conv2D)             (None, 512, 512, 4)  292         ['concatenate_19[0][0]']         
                                                                                                  
 dropout_26 (Dropout)           (None, 512, 512, 4)  0           ['conv2d_54[0][0]']              
                                                                                                  
 conv2d_55 (Conv2D)             (None, 512, 512, 4)  148         ['dropout_26[0][0]']             
                                                                                                  
 conv2d_56 (Conv2D)             (None, 512, 512, 1)  5           ['conv2d_55[0][0]']              
                                                                                                  
 tf.compat.v1.squeeze_1 (TFOpLa  None                0           ['conv2d_56[0][0]']              
 mbda)                                                                                            
                                                                                                  
==================================================================================================
Total params: 121,725
Trainable params: 121,725
Non-trainable params: 0
__________________________________________________________________________________________________
In [57]:
model.fit(train_df2, train_mask2, validation_data=(test_df2, test_mask2), epochs=10, batch_size=2, class_weight= {0:100, 1:1})
Epoch 1/10
3/3 [==============================] - 0s 64ms/step - loss: 0.1007 - val_loss: 0.1401
Epoch 2/10
3/3 [==============================] - 0s 49ms/step - loss: 0.1000 - val_loss: 0.1401
Epoch 3/10
3/3 [==============================] - 0s 53ms/step - loss: 0.1010 - val_loss: 0.1401
Epoch 4/10
3/3 [==============================] - 0s 51ms/step - loss: 0.1009 - val_loss: 0.1401
Epoch 5/10
3/3 [==============================] - 0s 58ms/step - loss: 0.1005 - val_loss: 0.1401
Epoch 6/10
3/3 [==============================] - 0s 44ms/step - loss: 0.1007 - val_loss: 0.1401
Epoch 7/10
3/3 [==============================] - 0s 51ms/step - loss: 0.1001 - val_loss: 0.1401
Epoch 8/10
3/3 [==============================] - 0s 51ms/step - loss: 0.1009 - val_loss: 0.1401
Epoch 9/10
3/3 [==============================] - 0s 46ms/step - loss: 0.1005 - val_loss: 0.1401
Epoch 10/10
3/3 [==============================] - 0s 49ms/step - loss: 0.1005 - val_loss: 0.1401
Out[57]:
<keras.callbacks.History at 0x7f3e2c015d50>
In [54]:
plt.imshow(np.squeeze(model.predict(tf.expand_dims(train_df2[0], 0))))
Out[54]:
<matplotlib.image.AxesImage at 0x7f3e2c0dae50>